{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bb75a2b9",
   "metadata": {},
   "source": [
    "Copyright (c) Facebook, Inc. and its affiliates.\n",
    "All rights reserved.\n",
    "\n",
    "This source code is licensed under the license found in the\n",
    "LICENSE file in the root directory of this source tree."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81a8ddb6",
   "metadata": {},
   "source": [
    "# IC-GAN\n",
    "\n",
    "\n",
    "Official Colab notebook from the paper <b>\"Instance-Conditioned GAN\"</b> by Arantxa Casanova, Marlene Careil, Jakob Verbeek, Michał Drożdżal, Adriana Romero-Soriano.\n",
    "\n",
    "This Colab provides the code to generate images with IC-GAN, with the option of further guiding the generation with captions (CLIP). \n",
    "\n",
    "Based on the Colab [WanderClip](https://j.mp/wanderclip) by Eyal Gruss [@eyaler](https://twitter.com/eyaler) [eyalgruss.com](https://eyalgruss.com)\n",
    "\n",
    "Using the work from [our repository](https://github.com/facebookresearch/ic_gan)\n",
    "\n",
    "https://github.com/openai/CLIP, Copyright (c) 2021 OpenAI\n",
    "\n",
    "https://github.com/huggingface/pytorch-pretrained-BigGAN, Copyright (c) 2019 Thomas Wolf\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9442671e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Restart after running this cell!\n",
    "\n",
    "# Installs a CUDA-matched PyTorch build. Restart the runtime afterwards so the\n",
    "# newly installed torch is the one imported by the following cells.\n",
    "!nvidia-smi -L\n",
    "\n",
    "import subprocess\n",
    "\n",
    "# nvcc prints a line like: Cuda compilation tools, release 11.2, V11.2.152 -- keep the number after 'release'.\n",
    "nvcc_output = subprocess.check_output([\"nvcc\", \"--version\"]).decode(\"UTF-8\")\n",
    "CUDA_version = [s for s in nvcc_output.split(\", \") if s.startswith(\"release\")][0].split(\" \")[-1]\n",
    "print(\"CUDA version:\", CUDA_version)\n",
    "\n",
    "# Wheel suffix for torch 1.8.0; any other toolkit version falls back to the CUDA 11.1 build.\n",
    "torch_version_suffix = {\"10.1\": \"+cu101\", \"10.2\": \"+cu102\"}.get(CUDA_version, \"+cu111\")\n",
    "\n",
    "# torchvision 0.9.0 is the release built against torch 1.8.0 (0.8.2 pins torch 1.7.1,\n",
    "# so that pair cannot be installed together).\n",
    "!pip install torch==1.8.0{torch_version_suffix} torchvision==0.9.0{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b01f51f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Setup\n",
    "# Work from /content so the clones and downloads land in the same place even when\n",
    "# this cell is re-run after a later cell has changed the working directory.\n",
    "%cd /content\n",
    "!git clone https://github.com/facebookresearch/ic_gan.git\n",
    "\n",
    "# Download (skipped if the archive is already present, -nc) and uncompress required files.\n",
    "# NOTE(review): the CLIP cell loads '*_nofeataug' experiments, which none of these archive\n",
    "# names mentions; confirm those checkpoints are shipped in one of them.\n",
    "!wget -nc https://dl.fbaipublicfiles.com/ic_gan/cc_icgan_biggan_imagenet_res256.tar.gz\n",
    "!tar -xvf cc_icgan_biggan_imagenet_res256.tar.gz\n",
    "!wget -nc https://dl.fbaipublicfiles.com/ic_gan/icgan_biggan_imagenet_res256.tar.gz\n",
    "!tar -xvf icgan_biggan_imagenet_res256.tar.gz\n",
    "!wget -nc https://dl.fbaipublicfiles.com/ic_gan/stored_instances.tar.gz\n",
    "!tar -xvf stored_instances.tar.gz\n",
    "\n",
    "!pip install pytorch-pretrained-biggan\n",
    "\n",
    "!git clone --depth 1 https://github.com/openai/CLIP\n",
    "!pip install ftfy\n",
    "%cd /content/CLIP\n",
    "import clip\n",
    "last_clip_model = 'ViT-B/32'\n",
    "perceptor, preprocess = clip.load(last_clip_model)\n",
    "\n",
    "import nltk\n",
    "nltk.download('wordnet')\n",
    "\n",
    "!pip install cma\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6c7629",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Prepare functions\n",
    "from pytorch_pretrained_biggan import BigGAN, convert_to_images, one_hot_from_names, utils\n",
    "\n",
    "%cd /content/ic_gan/\n",
    "import sys\n",
    "import os\n",
    "# Replace the first sys.path entry with the repo's inference/ dir and add the repo root\n",
    "# (inference/..), presumably so the repo modules imported below (inference.utils,\n",
    "# data_utils.utils, BigGAN_PyTorch) resolve.\n",
    "sys.path[0] = '/content/ic_gan/inference'\n",
    "sys.path.insert(1, os.path.join(sys.path[0], \"..\"))\n",
    "import torch \n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import sys\n",
    "# Seed torch's RNG with a value drawn from numpy's global RNG.\n",
    "torch.manual_seed(np.random.randint(sys.maxsize))\n",
    "import imageio\n",
    "from IPython.display import HTML, Image, clear_output\n",
    "from PIL import Image as Image_PIL\n",
    "from scipy.stats import truncnorm, dirichlet\n",
    "from torch import nn\n",
    "from nltk.corpus import wordnet as wn\n",
    "from base64 import b64encode\n",
    "from time import time\n",
    "import cma\n",
    "from cma.sigma_adaptation import CMAAdaptSigmaCSA, CMAAdaptSigmaTPA\n",
    "import warnings\n",
    "# Silence cma's InjectionWarning.\n",
    "warnings.simplefilter(\"ignore\", cma.evolution_strategy.InjectionWarning)\n",
    "import torchvision.transforms as transforms\n",
    "import inference.utils as inference_utils\n",
    "import data_utils.utils as data_utils\n",
    "from BigGAN_PyTorch.BigGAN import Generator as generator\n",
    "import sklearn.metrics\n",
    "\n",
    "def replace_to_inplace_relu(model):\n",
    "    \"\"\"Recursively replace every nn.ReLU submodule of `model` with a new nn.ReLU.\n",
    "\n",
    "    Adapted from https://github.com/minyoungg/pix2latent/blob/master/pix2latent/model/biggan.py\n",
    "    (the original comment credits it with saving memory). NOTE(review): despite the\n",
    "    function name, the replacements are built with inplace=False; confirm intent.\n",
    "    `model` is modified in place and None is returned.\n",
    "    \"\"\"\n",
    "    for attr_name, submodule in model.named_children():\n",
    "        if not isinstance(submodule, nn.ReLU):\n",
    "            replace_to_inplace_relu(submodule)\n",
    "            continue\n",
    "        setattr(model, attr_name, nn.ReLU(inplace=False))\n",
    "    \n",
    "def save(out,name=None, torch_format=True):\n",
    "  \"\"\"Convert the first image of a batch with convert_to_images and optionally write it to disk.\n",
    "\n",
    "  Args:\n",
    "    out: batch handed to convert_to_images; when torch_format is True it is a torch\n",
    "      tensor that is first moved to the CPU and turned into a numpy array.\n",
    "    name: optional output path; when truthy the image is written there with imageio.\n",
    "    torch_format: whether `out` still needs the tensor -> numpy conversion.\n",
    "\n",
    "  Returns:\n",
    "    The first element returned by convert_to_images.\n",
    "  \"\"\"\n",
    "  array = out\n",
    "  if torch_format:\n",
    "    with torch.no_grad():\n",
    "      array = array.cpu().numpy()\n",
    "  first_image = convert_to_images(array)[0]\n",
    "  if name:\n",
    "    imageio.imwrite(name, np.asarray(first_image))\n",
    "  return first_image\n",
    "\n",
    "# One row of statistics (vals0 + vals1 + vals2) is appended here per checkin() call.\n",
    "hist = []\n",
    "def checkin(i, best_ind, total_losses, losses, regs, out, noise=None, emb=None, probs=None):\n",
    "  # Log one optimisation step: save/show the image `out` and print loss statistics.\n",
    "  # total_losses/losses/regs hold one value per population member; best_ind indexes\n",
    "  # the best one. Reads the notebook globals save_every, show_every and sample_num\n",
    "  # (incremented here); the probs branch additionally needs ind2name and pop_size.\n",
    "  global sample_num, hist\n",
    "  name = None\n",
    "  # Write a numbered frame (later concatenated into the video) every save_every iterations.\n",
    "  if save_every and i%save_every==0:\n",
    "    name = '/content/output/frame_%05d.jpg'%sample_num\n",
    "  pil_image = save(out, name)\n",
    "  vals0 = [sample_num, i, total_losses[best_ind], losses[best_ind], regs[best_ind], np.mean(total_losses), np.mean(losses), np.mean(regs), np.std(total_losses), np.std(losses), np.std(regs)]\n",
    "  stats = 'sample=%d iter=%d best: total=%.2f cos=%.2f reg=%.3f avg: total=%.2f cos=%.2f reg=%.3f std: total=%.2f cos=%.2f reg=%.3f'%tuple(vals0)\n",
    "  vals1 = []\n",
    "  if noise is not None:\n",
    "    vals1 = [np.mean(noise), np.std(noise)]\n",
    "    stats += ' noise: avg=%.2f std=%.3f'%tuple(vals1)\n",
    "  vals2 = []\n",
    "  if emb is not None:\n",
    "    vals2 = [emb.mean(),emb.std()]\n",
    "    stats += ' emb: avg=%.2f std=%.3f'%tuple(vals2)\n",
    "  # Class-probability summary; no caller in this notebook passes probs.\n",
    "  elif probs:\n",
    "    best = probs[best_ind]\n",
    "    inds = np.argsort(best)[::-1]\n",
    "    probs = np.array(probs)\n",
    "    vals2 = [ind2name[inds[0]], best[inds[0]], ind2name[inds[1]], best[inds[1]], ind2name[inds[2]], best[inds[2]], np.sum(probs >= 0.5)/pop_size,np.sum(probs >= 0.3)/pop_size,np.sum(probs >= 0.1)/pop_size]\n",
    "    stats += ' 1st=%s(%.2f) 2nd=%s(%.2f) 3rd=%s(%.2f) components: >=0.5:%.0f, >=0.3:%.0f, >=0.1:%.0f'%tuple(vals2)\n",
    "  hist.append(vals0+vals1+vals2)\n",
    "  # Replace the previous cell output with the current image every show_every iterations.\n",
    "  if show_every and i%show_every==0:\n",
    "    clear_output()\n",
    "    display(pil_image)  \n",
    "  print(stats)\n",
    "  sample_num += 1\n",
    "\n",
    "def load_icgan(experiment_name, root_ = '/content'):\n",
    "  \"\"\"Build the IC-GAN generator stored under <root_>/<experiment_name>, on GPU in eval mode.\n",
    "\n",
    "  The training config is read from that directory's state_dict_best0.pth and patched\n",
    "  (weights_root, model_backbone='biggan', experiment_name) before being handed to\n",
    "  inference_utils.load_model_inference.\n",
    "  \"\"\"\n",
    "  checkpoint_file = os.path.join(root_, experiment_name) + \"/state_dict_best0.pth\"\n",
    "  config = torch.load(checkpoint_file)['config']\n",
    "  overrides = {\"weights_root\": root_, \"model_backbone\": 'biggan', \"experiment_name\": experiment_name}\n",
    "  for key, value in overrides.items():\n",
    "    config[key] = value\n",
    "  generator_model, _ = inference_utils.load_model_inference(config)\n",
    "  generator_model.cuda()\n",
    "  generator_model.eval()\n",
    "  return generator_model\n",
    "\n",
    "def get_output(noise_vector, input_label, input_features):  \n",
    "  \"\"\"Run the global generator `model` on a batch of latent codes.\n",
    "\n",
    "  Args:\n",
    "    noise_vector: latent-code tensor on the GPU, one row per sample.\n",
    "    input_label: class indices (list or LongTensor) for cc_icgan, or None.\n",
    "    input_features: instance-feature tensor, or None.\n",
    "\n",
    "  Returns:\n",
    "    The generated batch; averaged to grey and repeated to 3 channels if the global\n",
    "    `channels` is 1.\n",
    "\n",
    "  Truncation (global `truncation`) is applied by clamping to +-2*truncation or, when\n",
    "  `stochastic_truncation` is set, by resampling out-of-range entries in place.\n",
    "  \"\"\"\n",
    "  if stochastic_truncation: #https://arxiv.org/abs/1702.04782\n",
    "    with torch.no_grad():\n",
    "      trunc_indices = noise_vector.abs() > 2*truncation\n",
    "      num_resampled = int(torch.count_nonzero(trunc_indices))\n",
    "      trunc = truncnorm.rvs(-2*truncation, 2*truncation, size=num_resampled).astype(np.float32)\n",
    "      # Write through .data so a grad-requiring leaf can be updated. The replacement\n",
    "      # values need no grad, so this no longer reads the global `requires_grad`\n",
    "      # (only defined in the CLIP cell, hence a NameError in the plain IC-GAN cell).\n",
    "      noise_vector.data[trunc_indices] = torch.from_numpy(trunc).to(noise_vector.device)\n",
    "  else:\n",
    "    noise_vector = noise_vector.clamp(-2*truncation, 2*truncation)\n",
    "  if input_label is not None:\n",
    "    input_label = torch.as_tensor(input_label, dtype=torch.long).cuda()\n",
    "  if input_features is not None:\n",
    "    input_features = input_features.cuda()\n",
    "\n",
    "  out = model(noise_vector, input_label, input_features)\n",
    "  \n",
    "  if channels==1:\n",
    "    out = out.mean(dim=1, keepdim=True)\n",
    "    out = out.repeat(1,3,1,1)\n",
    "  return out\n",
    "\n",
    "def normality_loss(vec): #https://arxiv.org/abs/1903.00925\n",
    "    \"\"\"Penalty pushing the entries of `vec` toward zero mean and unit variance.\n",
    "\n",
    "    Returns mean**2 + var - log(var) - 1 (torch's default unbiased var), which reaches\n",
    "    its minimum of 0 exactly when the mean is 0 and the variance is 1.\n",
    "    \"\"\"\n",
    "    variance = vec.var()\n",
    "    return vec.mean() ** 2 + variance - variance.log() - 1\n",
    "    \n",
    "\n",
    "def load_generative_model(gen_model, last_gen_model, experiment_name, model):\n",
    "  \"\"\"Return the IC-GAN generator for `experiment_name`, reusing `model` when possible.\n",
    "\n",
    "  Args:\n",
    "    gen_model: family chosen in the form ('icgan' / 'cc_icgan'); kept for backward\n",
    "      compatibility, the cache key is now `experiment_name`.\n",
    "    last_gen_model: cache key returned by the previous call (None initially).\n",
    "    experiment_name: checkpoint directory under /content to load.\n",
    "    model: generator returned by the previous call (None initially).\n",
    "\n",
    "  Returns:\n",
    "    (model, cache_key); pass both back in on the next call.\n",
    "  \"\"\"\n",
    "  # Key on the experiment rather than the family: the two generation cells map the same\n",
    "  # gen_model value to different checkpoints (..._res256 vs ..._res256_nofeataug), so\n",
    "  # keying on gen_model silently kept whichever checkpoint was loaded first.\n",
    "  if model is None or experiment_name != last_gen_model:\n",
    "    model = load_icgan(experiment_name, root_ = '/content')\n",
    "    last_gen_model = experiment_name\n",
    "  return model, last_gen_model\n",
    "\n",
    "def load_feature_extractor(gen_model, last_feature_extractor, feature_extractor):\n",
    "  \"\"\"Return the instance feature extractor matching `gen_model`, loading it only when needed.\n",
    "\n",
    "  cc_icgan uses the 'classification' extractor; any other gen_model uses the\n",
    "  'selfsupervised' one, whose SwAV checkpoint is first fetched to /content with curl.\n",
    "  Returns (feature_extractor, extractor_name); pass both back in on the next call so\n",
    "  an already-loaded extractor of the same kind is reused.\n",
    "  \"\"\"\n",
    "  wanted = 'selfsupervised'\n",
    "  if gen_model == 'cc_icgan':\n",
    "    wanted = 'classification'\n",
    "  if last_feature_extractor == wanted:\n",
    "    # Already loaded: reuse it as-is.\n",
    "    return feature_extractor, last_feature_extractor\n",
    "\n",
    "  weights_path = ''\n",
    "  if wanted == 'selfsupervised':\n",
    "    !curl -L -o /content/swav_pretrained.pth.tar -C - 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar' \n",
    "    weights_path = '/content/swav_pretrained.pth.tar'\n",
    "  extractor = data_utils.load_pretrained_feature_extractor(weights_path, feature_extractor = wanted)\n",
    "  extractor.eval()\n",
    "  return extractor, wanted\n",
    "\n",
    "# ImageNet channel mean/std used to normalise feature-extractor inputs.\n",
    "norm_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)\n",
    "norm_std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)\n",
    "\n",
    "def preprocess_input_image(input_image_path, size): \n",
    "  \"\"\"Load an image file as a normalised (1, 3, 224, 224) tensor for the feature extractor.\n",
    "\n",
    "  The image is converted to RGB, cropped with data_utils.CenterCropLongEdge, resized to\n",
    "  size x size, normalised with norm_mean / norm_std, and finally resampled to 224 x 224\n",
    "  with bicubic interpolation.\n",
    "  \"\"\"\n",
    "  pipeline = transforms.Compose([\n",
    "      data_utils.CenterCropLongEdge(),\n",
    "      transforms.Resize((size,size)),\n",
    "      transforms.ToTensor(),\n",
    "      transforms.Normalize(norm_mean, norm_std),\n",
    "  ])\n",
    "  rgb_image = Image_PIL.open(input_image_path).convert('RGB')\n",
    "  batch = pipeline(rgb_image).unsqueeze(0)\n",
    "  return torch.nn.functional.interpolate(batch, 224, mode=\"bicubic\", align_corners=True)\n",
    "\n",
    "def preprocess_generated_image(image): \n",
    "  \"\"\"Prepare a batch of generator outputs for the feature extractor.\n",
    "\n",
    "  Maps values from (presumably) [-1, 1] to [0, 1], normalises with norm_mean /\n",
    "  norm_std and resamples bicubically to 224 x 224.\n",
    "  \"\"\"\n",
    "  unit_range = image*0.5 + 0.5\n",
    "  normalised = transforms.Normalize(norm_mean, norm_std)(unit_range)\n",
    "  return torch.nn.functional.interpolate(normalised, 224, mode=\"bicubic\", align_corners=True)\n",
    "\n",
    "# Caches shared by the generation cells: the load_* helpers compare against and\n",
    "# return these so a model is only (re)loaded when the selection changes.\n",
    "last_gen_model = None\n",
    "last_feature_extractor = None\n",
    "model = None\n",
    "feature_extractor = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17278e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Generate images with IC-GAN!\n",
    "#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
    "#@markdown 1. Select which instance to condition on, following one of the following options:\n",
    "#@markdown     1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
    "#@markdown     1. **input_feature_index** write an integer from 0 to 999. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
    "#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\") write an integer from 0 to 999. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for a correspondence between class name and indexes.\n",
    "#@markdown 1. **num_samples_ranked** (default=16) indicates the number of generated images to output in a mosaic. These generated images are the ones whose features are closest to the conditioning instance, out of **num_samples_total** (default=160) generated samples. Increasing \"num_samples_total\" will result in higher run times, but more generated images to choose the top \"num_samples_ranked\" from, and therefore higher chance of better image quality. Reducing \"num_samples_total\" too much could result in generated images with poorer visual quality. A ratio of 10:1 (num_samples_total:num_samples_ranked) is recommended, and a perfect square (e.g. 16) fills the mosaic completely.\n",
    "#@markdown 1. Vary **truncation** (default=0.7) from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Truncation values between 0.7 and 0.9 seem to empirically work well.\n",
    "#@markdown 1. **seed**=0 means no seed.\n",
    "\n",
    "gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
    "if gen_model == 'icgan':  \n",
    "  experiment_name = 'icgan_biggan_imagenet_res256'\n",
    "else:\n",
    "  experiment_name = 'cc_icgan_biggan_imagenet_res256'\n",
    "size = '256'\n",
    "input_image_instance = \"\"#@param {type:\"string\"}\n",
    "input_feature_index =   3#@param {type:'integer'}\n",
    "class_index =   538#@param {type:'integer'}\n",
    "num_samples_ranked =   16#@param {type:'integer'}\n",
    "num_samples_total =    160#@param {type:'integer'}\n",
    "truncation =  0.7#@param {type:'number'}\n",
    "stochastic_truncation = False #@param {type:'boolean'}\n",
    "download_file = True #@param {type:'boolean'}\n",
    "seed =  50#@param {type:'number'}\n",
    "if seed == 0:\n",
    "  seed = None\n",
    "noise_size = 128\n",
    "class_size = 1000\n",
    "channels = 3\n",
    "batch_size = 4\n",
    "if gen_model == 'icgan':\n",
    "  class_index = None\n",
    "\n",
    "assert(num_samples_ranked <=num_samples_total)\n",
    "import numpy as np\n",
    "state = None if not seed else np.random.RandomState(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
    "\n",
    "# Load feature extractor (outlier filtering and optionally input image feature extraction)\n",
    "feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
    "# Load features \n",
    "if input_image_instance not in ['None', \"\"]:\n",
    "  print('Obtaining instance features from input image!')\n",
    "  input_feature_index = None\n",
    "  input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
    "  print('Displaying instance conditioning:')\n",
    "  display(convert_to_images(((input_image_tensor*norm_std + norm_mean)-0.5) / 0.5)[0])\n",
    "  with torch.no_grad():\n",
    "    input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
    "  input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
    "elif input_feature_index is not None:\n",
    "  print('Selecting an instance from pre-extracted vectors!')\n",
    "  input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
    "else:\n",
    "  input_features = None\n",
    "\n",
    "# Load generative model\n",
    "model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
    "# Prepare other variables\n",
    "name_file = '%s_class_index%s_instance_index%s'%(gen_model, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
    "\n",
    "!rm -rf /content/output\n",
    "!mkdir -p /content/output\n",
    "\n",
    "replace_to_inplace_relu(model)\n",
    "ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
    "\n",
    "from google.colab import files, output\n",
    "\n",
    "eps = 1e-8\n",
    "\n",
    "# Create noise, instance and class vector\n",
    "noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(num_samples_total, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
    "noise_vector = torch.tensor(noise_vector, requires_grad=False, device='cuda')\n",
    "if input_features is not None:\n",
    "  instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda').repeat(num_samples_total, 1)\n",
    "else: \n",
    "  instance_vector = None\n",
    "if class_index is not None:\n",
    "  print('Conditioning on class: ', ind2name[class_index])\n",
    "  input_label = torch.LongTensor([class_index]*num_samples_total)\n",
    "else:\n",
    "  input_label = None\n",
    "if input_feature_index is not None:\n",
    "  print('Conditioning on instance with index: ', input_feature_index)\n",
    "\n",
    "size = int(size)\n",
    "all_outs, all_dists = [], []\n",
    "# Inference only: no_grad keeps an autograd graph from being built for every batch.\n",
    "with torch.no_grad():\n",
    "  for start in range(0, num_samples_total, batch_size):\n",
    "    end = min(start+batch_size, num_samples_total)\n",
    "    batch_instances = instance_vector[start:end] if instance_vector is not None else None\n",
    "    out = get_output(noise_vector[start:end], input_label[start:end] if input_label is not None else None, batch_instances)\n",
    "    # Keep every sample, even when there is no conditioning instance to rank against.\n",
    "    all_outs.append(out.cpu())\n",
    "    if batch_instances is not None:\n",
    "      # Euclidean distance between each sample's normalised features and its own\n",
    "      # conditioning instance (the diagonal of the pairwise distance matrix).\n",
    "      out_features, _ = feature_extractor(preprocess_generated_image(out).cuda())\n",
    "      out_features /= torch.linalg.norm(out_features, dim=-1, keepdims=True)\n",
    "      all_dists.append(torch.linalg.norm(out_features - batch_instances, dim=-1).cpu().numpy())\n",
    "    del out\n",
    "all_outs = torch.cat(all_outs)\n",
    "\n",
    "# Order samples by distance to the conditioning feature vector and select only num_samples_ranked\n",
    "# images; without an instance to compare against, keep the generation order.\n",
    "if all_dists:\n",
    "  all_dists = np.concatenate(all_dists)\n",
    "  selected_idxs = np.argsort(all_dists)[:num_samples_ranked]\n",
    "else:\n",
    "  selected_idxs = np.arange(num_samples_ranked)\n",
    "\n",
    "# Create figure: tile the images column by column into a grid x grid mosaic. Rounding\n",
    "# the grid up keeps a non-square num_samples_ranked from overwriting earlier tiles.\n",
    "grid = int(np.ceil(np.sqrt(num_samples_ranked)))\n",
    "all_images_mosaic = np.zeros((3, size*grid, size*grid))\n",
    "for k, j in enumerate(selected_idxs):\n",
    "  col_i, row_i = divmod(k, grid)\n",
    "  all_images_mosaic[:, row_i*size:row_i*size+size, col_i*size:col_i*size+size] = all_outs[j]\n",
    "\n",
    "name = '/content/%s_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
    "pil_image = save(all_images_mosaic[np.newaxis,...],name, torch_format=False)  \n",
    "print('Displaying generated images')\n",
    "display(pil_image)\n",
    "\n",
    "if download_file:\n",
    "  files.download(name)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5ee254",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Generate images with IC-GAN + CLIP!\n",
    "#@markdown 1. For **prompt** OpenAI suggests using the template \"A photo of a X.\" or \"A photo of a X, a type of Y.\" [[paper]](https://cdn.openai.com/papers/Learning_Transferable_Visual_Models_From_Natural_Language_Supervision.pdf)\n",
    "#@markdown 1. Select type of IC-GAN model with **gen_model**: \"icgan\" is conditioned on an instance; \"cc_icgan\" is conditioned on both instance and a class index.\n",
    "#@markdown 1. Select which instance to condition on, following one of the following options:\n",
    "#@markdown     1. **input_image_instance** is the path to an input image, from either the mounted Google Drive or a manually uploaded image to \"Files\" (left part of the screen).\n",
    "#@markdown     1. **input_feature_index** write an integer from 0 to 999. This will change the instance conditioning and therefore the style and semantics of the generated images. This will select one of the 1000 instance features pre-selected from ImageNet using k-means.\n",
    "#@markdown 1. For **class_index** (only valid for gen_model=\"cc_icgan\") write an integer from 0 to 999. This will change the ImageNet class to condition on. Consult [this link](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for a correspondence between class name and indexes.\n",
    "#@markdown 1. Vary **truncation** from 0 to 1 to apply the [truncation trick](https://arxiv.org/abs/1809.11096). Truncation=1 will provide more diverse but possibly poorer quality images. Truncation values between 0.7 and 0.9 seem to empirically work well.\n",
    "#@markdown 1. **seed**=0 means no seed.\n",
    "prompt = 'A dragon' #@param {type:'string'}\n",
    "gen_model = 'icgan' #@param ['icgan', 'cc_icgan']\n",
    "# NOTE(review): these '*_nofeataug' checkpoints differ from the plain IC-GAN cell's, and the\n",
    "# Setup cell downloads no archive with that name; confirm they are available under /content.\n",
    "if gen_model == 'icgan':  \n",
    "  experiment_name = 'icgan_biggan_imagenet_res256_nofeataug'\n",
    "else:\n",
    "  experiment_name = 'cc_icgan_biggan_imagenet_res256_nofeataug'\n",
    "size = '256'\n",
    "input_image_instance = \"\"#@param {type:\"string\"}\n",
    "\n",
    "input_feature_index =   500#@param {type:'integer'}\n",
    "class_index =  627 #@param {type:'integer'} (only with cc_icgan)\n",
    "download_image = False #@param {type:'boolean'}\n",
    "download_video = False #@param {type:'boolean'}\n",
    "truncation = 0.7 #@param {type:'number'}\n",
    "stochastic_truncation = False #@param {type:'boolean'}\n",
    "optimizer = 'CMA-ES' #@param ['SGD','Adam','CMA-ES','CMA-ES + SGD interleaved','CMA-ES + Adam interleaved','CMA-ES + terminal SGD','CMA-ES + terminal Adam']\n",
    "pop_size = 50 #@param {type:'integer'}\n",
    "clip_model = 'ViT-B/32' #@param ['ViT-B/32','RN50','RN101','RN50x4']\n",
    "augmentations =  64#@param {type:'integer'}\n",
    "learning_rate =  0.1#@param {type:'number'}\n",
    "noise_normality_loss =  0#@param {type:'number'}\n",
    "minimum_entropy_loss = 0.0001 #@param {type:'number'}\n",
    "total_variation_loss = 0.1 #@param {type:'number'}\n",
    "iterations =  100#@param {type:'integer'}\n",
    "terminal_iterations =  100#@param {type:'integer'}\n",
    "show_every = 1 #@param {type:'integer'}\n",
    "save_every = 1 #@param {type:'integer'}\n",
    "fps =  2#@param {type:'number'}\n",
    "freeze_secs = 0 #@param {type:'number'}\n",
    "seed =  10#@param {type:'number'}\n",
    "if seed == 0:\n",
    "  seed = None\n",
    "\n",
    "softmax_temp = 1\n",
    "emb_factor = 0.067 #calculated empirically \n",
    "loss_factor = 100\n",
    "sigma0 = 0.5 #http://cma.gforge.inria.fr/cmaes_sourcecode_page.html#practical\n",
    "cma_adapt = True\n",
    "cma_diag = False\n",
    "cma_active = True\n",
    "cma_elitist = False\n",
    "noise_size = 128\n",
    "class_size = 1000\n",
    "channels = 3\n",
    "if gen_model == 'icgan':\n",
    "  class_index = None\n",
    "\n",
    "import numpy as np\n",
    "state = None if not seed else np.random.RandomState(seed)\n",
    "np.random.seed(seed)\n",
    "# Load features \n",
    "if input_image_instance not in ['None',\"\"]:\n",
    "  print('Obtaining instance features from input image!')\n",
    "  input_feature_index = None\n",
    "  feature_extractor, last_feature_extractor = load_feature_extractor(gen_model, last_feature_extractor, feature_extractor)\n",
    "  input_image_tensor = preprocess_input_image(input_image_instance, int(size))\n",
    "  # Inference only: don't build an autograd graph through the feature extractor.\n",
    "  with torch.no_grad():\n",
    "    input_features, _ = feature_extractor(input_image_tensor.cuda())\n",
    "  input_features/=torch.linalg.norm(input_features,dim=-1, keepdims=True)\n",
    "elif input_feature_index is not None:\n",
    "  print('Selecting an instance from pre-extracted vectors!')\n",
    "  feature_extractor_name = 'classification' if gen_model == 'cc_icgan' else 'selfsupervised'\n",
    "  input_features = np.load('/content/stored_instances/imagenet_res'+str(size)+'_rn50_'+feature_extractor_name+'_kmeans_k1000_instance_features.npy', allow_pickle=True).item()[\"instance_features\"][input_feature_index:input_feature_index+1]\n",
    "else:\n",
    "  input_features = None\n",
    "\n",
    "# Load generative model\n",
    "model, last_gen_model = load_generative_model(gen_model, last_gen_model, experiment_name, model)\n",
    "\n",
    "# Load CLIP model\n",
    "if clip_model != last_clip_model:\n",
    "  perceptor, preprocess = clip.load(clip_model)\n",
    "  last_clip_model = clip_model\n",
    "clip_res = perceptor.visual.input_resolution\n",
    "sideX = sideY = int(size)\n",
    "if sideX<=clip_res and sideY<=clip_res:\n",
    "  augmentations = 1\n",
    "if 'CMA' not in optimizer:\n",
    "  pop_size = 1\n",
    "\n",
    "# Prepare other variables\n",
    "name_file = '%s_%s_class_index%s_instance_index%s'%(gen_model, prompt, str(class_index) if class_index is not None else 'None', str(input_feature_index) if input_feature_index is not None else 'None')\n",
    "requires_grad = ('SGD' in optimizer or 'Adam' in optimizer) and ('terminal' not in optimizer or terminal_iterations>0)\n",
    "total_iterations = iterations + terminal_iterations*('terminal' in optimizer)\n",
    "\n",
    "!rm -rf /content/output\n",
    "!mkdir -p /content/output\n",
    "\n",
    "replace_to_inplace_relu(model)\n",
    "replace_to_inplace_relu(perceptor)\n",
    "ind2name = {index: wn.of2ss('%08dn'%offset).lemma_names()[0] for offset, index in utils.IMAGENET.items()}\n",
    "eps = 1e-8\n",
    "\n",
    "# Create noise and instance vector\n",
    "noise_vector = truncnorm.rvs(-2*truncation, 2*truncation, size=(pop_size, noise_size), random_state=state).astype(np.float32) #see https://github.com/tensorflow/hub/issues/214\n",
    "noise_vector = torch.tensor(noise_vector, requires_grad=requires_grad, device='cuda')\n",
    "if input_features is not None:\n",
    "  instance_vector = torch.tensor(input_features, requires_grad=False, device='cuda')\n",
    "else: \n",
    "  instance_vector = None\n",
    "if class_index is not None:\n",
    "  print('Conditioning on class: ', ind2name[class_index])\n",
    "if input_feature_index is not None:\n",
    "  print('Conditioning on instance with index: ', input_feature_index)\n",
    "\n",
    "# Prepare optimizer\n",
    "if requires_grad:\n",
    "  params = [noise_vector]\n",
    "  if 'SGD' in optimizer:\n",
    "    optim = torch.optim.SGD(params, lr=learning_rate, momentum=0.9)  \n",
    "  else:\n",
    "    optim = torch.optim.Adam(params, lr=learning_rate)\n",
    "\n",
    "def ascend_txt(i, grad_step=False, show_save=False):\n",
    "  \"\"\"Score each population member against the CLIP text target; optionally take a gradient step.\n",
    "\n",
    "  For every row j of the global noise_vector, the generator output is turned into\n",
    "  `augmentations` views (random square crops unless the image already fits CLIP's input\n",
    "  resolution or augmentations == 1), resized to clip_res, encoded with `perceptor`, and\n",
    "  scored as loss_factor * (1 - mean cosine similarity to target_clip). The regulariser\n",
    "  global_reg (optional noise-normality term, scaled by loss_factor) is shared by all rows.\n",
    "\n",
    "  Args:\n",
    "    i: current iteration, used for best-so-far bookkeeping and logging cadence.\n",
    "    grad_step: backpropagate every member's loss and call optim.step().\n",
    "    show_save: call checkin() on iterations selected by save_every/show_every.\n",
    "\n",
    "  Returns:\n",
    "    (total_losses, best_ind): per-member total losses and the index of the lowest one.\n",
    "  \"\"\"\n",
    "  global global_best_loss, global_best_iteration, global_best_noise_vector\n",
    "  regs = []\n",
    "  losses = []\n",
    "  total_losses = []\n",
    "  best_loss = np.inf\n",
    "  global_reg = torch.tensor(0, device='cuda', dtype=torch.float32, requires_grad=grad_step)\n",
    "  if noise_normality_loss:\n",
    "    global_reg = global_reg+noise_normality_loss*normality_loss(noise_vector)\n",
    "  global_reg = loss_factor*global_reg  \n",
    "  if grad_step:\n",
    "    global_reg.backward()\n",
    "  for j in range(pop_size):\n",
    "    p_s = []\n",
    "    out = get_output(noise_vector[j:j+1], [class_index] if class_index is not None else None, instance_vector)\n",
    "    for aug in range(augmentations):\n",
    "      if sideX<=clip_res and sideY<=clip_res or augmentations==1:\n",
    "        apper = out  \n",
    "      else:\n",
    "        crop_size = torch.randint(int(.7*sideX), int(.98*sideX), ())\n",
    "        offsetx = torch.randint(0, sideX - crop_size, ())\n",
    "        offsety = torch.randint(0, sideX - crop_size, ())\n",
    "        apper = out[:, :, offsetx:offsetx + crop_size, offsety:offsety + crop_size]\n",
    "      apper = (apper+1)/2\n",
    "      apper = nn.functional.interpolate(apper, clip_res, mode='bilinear')\n",
    "      p_s.append(apper)\n",
    "    into = nom(torch.cat(p_s, 0))\n",
    "    predict_clip = perceptor.encode_image(into)\n",
    "    loss = loss_factor*(1-torch.cosine_similarity(predict_clip, target_clip).mean())\n",
    "    total_loss = loss\n",
    "    regs.append(global_reg.item())\n",
    "\n",
    "    with torch.no_grad():\n",
    "      losses.append(loss.item())\n",
    "      total_losses.append(total_loss.item()+global_reg.item())\n",
    "    if total_losses[-1]<best_loss:\n",
    "      best_loss = total_losses[-1]\n",
    "      best_ind = j\n",
    "      best_out = out\n",
    "      if best_loss < global_best_loss:\n",
    "        global_best_loss = best_loss\n",
    "        global_best_iteration = i\n",
    "        with torch.no_grad():\n",
    "          # Store a copy, not a view: optim.step() updates noise_vector in place, which\n",
    "          # would otherwise silently overwrite the recorded best vector.\n",
    "          global_best_noise_vector = noise_vector[best_ind].detach().clone()\n",
    "    if grad_step:    \n",
    "      total_loss.backward()\n",
    "\n",
    "  if grad_step:\n",
    "    optim.step()\n",
    "    optim.zero_grad()\n",
    "\n",
    "  if show_save and (save_every and i % save_every == 0 or show_every and i % show_every == 0):\n",
    "    noise = None\n",
    "    emb = None\n",
    "    with torch.no_grad():\n",
    "      noise = noise_vector.cpu().numpy()\n",
    "    checkin(i, best_ind, total_losses, losses, regs, best_out, noise, emb)  \n",
    "  return total_losses, best_ind\n",
    "\n",
    "# Obtain target CLIP representation\n",
    "tx = clip.tokenize(prompt)\n",
    "with torch.no_grad():\n",
    "  target_clip = perceptor.encode_text(tx.cuda())\n",
    "\n",
    "\n",
    "global_best_loss = np.inf\n",
    "global_best_iteration = 0\n",
    "global_best_noise_vector = None\n",
    "\n",
    "nom = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
    "if 'CMA' in optimizer:\n",
    "  initial_vector = np.zeros(noise_size)\n",
    "  bounds = None\n",
    "  cma_opts = {'popsize': pop_size, 'seed': np.nan, 'AdaptSigma': cma_adapt, 'CMA_diagonal': cma_diag, 'CMA_active': cma_active, 'CMA_elitist':cma_elitist, 'bounds':bounds}\n",
    "  cmaes = cma.CMAEvolutionStrategy(initial_vector, sigma0, inopts=cma_opts)\n",
    "\n",
    "sample_num = 0\n",
    "machine = !nvidia-smi -L\n",
    "start = time()\n",
    "\n",
    "# Start noise vector optimization\n",
    "for i in range(total_iterations):    \n",
    "  if 'CMA' in optimizer and i<iterations:\n",
    "    with torch.no_grad():\n",
    "      cma_results = torch.tensor(cmaes.ask(), dtype=torch.float32).cuda()\n",
    "      noise_vector.data = cma_results      \n",
    "  if requires_grad and ('terminal' not in optimizer or i>=iterations):\n",
    "    losses, best_ind = ascend_txt(i, grad_step=True, show_save='CMA' not in optimizer or i>=iterations)\n",
    "    # The optimizer only updates noise_vector, so it must remain a grad-requiring leaf.\n",
    "    assert noise_vector.requires_grad and noise_vector.is_leaf, (noise_vector.requires_grad, noise_vector.is_leaf)\n",
    "  if 'CMA' in optimizer and i<iterations:\n",
    "    with torch.no_grad():\n",
    "      losses, best_ind = ascend_txt(i, show_save=True)\n",
    "      if i<iterations-1:\n",
    "        vectors = noise_vector\n",
    "        cmaes.tell(vectors.cpu().numpy(), losses)\n",
    "      elif 'terminal' in optimizer and terminal_iterations:\n",
    "        pop_size = 1\n",
    "        noise_vector[0] = global_best_noise_vector\n",
    "  if save_every and i % save_every == 0 or show_every and i % show_every == 0:\n",
    "    print('took: %d secs (%.2f sec/iter) on %s. CUDA memory: %.1f GB'%(time()-start,(time()-start)/(i+1), machine[0], torch.cuda.max_memory_allocated()/1024**3))\n",
    "\n",
    "# Obtain generated image with lowest loss.\n",
    "out = get_output(global_best_noise_vector.unsqueeze(0), [class_index] if class_index is not None else None, instance_vector)\n",
    "name = '/content/%s_best_seed%i.png'%(name_file,seed if seed is not None else -1)\n",
    "pil_image = save(out,name)  \n",
    "display(pil_image)  \n",
    "print('best_loss=%.2f best_iter=%d'%(global_best_loss,global_best_iteration))\n",
    "\n",
    "if download_image:\n",
    "  from google.colab import files, output\n",
    "  files.download(name)\n",
    "\n",
    "if download_video:\n",
    "  out = '\"/content/%s_seed%i.mp4\"'%(name_file, seed if seed is not None else -1)\n",
    "  file_name = '/content/%s_seed%i.mp4'%(name_file, seed if seed is not None else -1)\n",
    "\n",
    "  # checkin() numbers frames by sample_num but only writes one when i % save_every == 0,\n",
    "  # so list the frames that actually exist and repeat the last one for freeze_secs.\n",
    "  import glob\n",
    "  frame_paths = sorted(glob.glob('/content/output/frame_*.jpg'))\n",
    "  if not frame_paths:\n",
    "    raise RuntimeError('download_video needs saved frames: set save_every > 0.')\n",
    "  frame_paths += [frame_paths[-1]] * int(freeze_secs*fps)\n",
    "  with open('/content/list.txt','w') as f:\n",
    "    for frame_path in frame_paths:\n",
    "      f.write('file %s\\n' % frame_path)\n",
    "  !ffmpeg -r $fps -f concat -safe 0 -i /content/list.txt -c:v libx264 -pix_fmt yuv420p -profile:v baseline -movflags +faststart -r $fps $out -y\n",
    "  with open(file_name, 'rb') as f:\n",
    "    data_url = \"data:video/mp4;base64,\" + b64encode(f.read()).decode()\n",
    "  display(HTML(\"\"\"\n",
    "    <video controls autoplay loop>\n",
    "          <source src=\"%s\" type=\"video/mp4\">\n",
    "    </video>\"\"\" % data_url))\n",
    "\n",
    "  from google.colab import files, output\n",
    "  output.eval_js('new Audio(\"https://freesound.org/data/previews/80/80921_1022651-lq.ogg\").play()')\n",
    "  files.download(file_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
